﻿using System;

namespace IOP.Models.Tree.BinaryTree
{
    /// <summary>
    /// AVL tree: a self-balancing binary search tree. After every insertion or deletion the
    /// heights along the affected path are recomputed, and single or double rotations restore
    /// the invariant |height(left) - height(right)| &lt;= 1 at every node.
    /// </summary>
    /// <typeparam name="TKey">Key type; ordering is taken from <see cref="IComparable{T}.CompareTo"/>.</typeparam>
    /// <typeparam name="TValue">Value type stored alongside each key.</typeparam>
    public class AVLTree<TKey, TValue> : BinarySearchTree<TKey, TValue>
        where TKey : IComparable<TKey>
    {

        /// <summary>
        /// Inserts <paramref name="key"/> with <paramref name="value"/>; if the key is already
        /// present its value is overwritten and <c>Count</c> is unchanged.
        /// </summary>
        /// <param name="key">Key to insert or update.</param>
        /// <param name="value">Value to associate with <paramref name="key"/>.</param>
        public override void Put(TKey key, TValue value)
        {
            // The recursive overload handles the empty tree as well, so the root node gets its
            // Height initialised exactly like every other newly created node.
            Root = Put(Root, key, value);
        }

        /// <summary>
        /// Recursively inserts/updates <paramref name="key"/> in the subtree rooted at
        /// <paramref name="current"/> and rebalances on the way back up.
        /// </summary>
        /// <param name="current">Subtree root; may be null (empty subtree).</param>
        /// <param name="key">Key to insert or update.</param>
        /// <param name="value">Value to store.</param>
        /// <returns>The (possibly new) root of the rebalanced subtree.</returns>
        private BinaryTreeNode<TKey, TValue> Put(BinaryTreeNode<TKey, TValue> current, TKey key, TValue value)
        {
            if (current == null)
            {
                var created = new BinaryTreeNode<TKey, TValue>(key, value);
                Count++;
                RecomputeHeight(created);
                return created;
            }

            // IComparable<T>.CompareTo only guarantees the sign of its result (byte/char/short
            // return the raw difference), so compare against zero rather than against +1/-1.
            int cmp = key.CompareTo(current.Key);
            if (cmp > 0)
            {
                current.Right = Put(current.Right, key, value);
                current.Right.Parent = current;
            }
            else if (cmp < 0)
            {
                current.Left = Put(current.Left, key, value);
                current.Left.Parent = current;
            }
            else
            {
                current.Value = value;
            }
            return Rebalance(current);
        }

        /// <summary>
        /// Removes <paramref name="key"/> from the tree; does nothing if the key is absent.
        /// </summary>
        /// <param name="key">Key to remove.</param>
        public override void Delete(TKey key)
        {
            Root = Delete(Root, key);
        }

        /// <summary>
        /// Recursively removes <paramref name="key"/> from the subtree rooted at
        /// <paramref name="current"/> and rebalances on the way back up.
        /// </summary>
        /// <param name="current">Subtree root; may be null (key not found).</param>
        /// <param name="key">Key to remove.</param>
        /// <returns>The (possibly new) root of the rebalanced subtree, or null if it became empty.</returns>
        private BinaryTreeNode<TKey, TValue> Delete(BinaryTreeNode<TKey, TValue> current, TKey key)
        {
            if (current == null) return null;

            // Sign-based comparison, see the note in Put.
            int cmp = key.CompareTo(current.Key);
            if (cmp > 0)
            {
                current.Right = Delete(current.Right, key);
            }
            else if (cmp < 0)
            {
                current.Left = Delete(current.Left, key);
            }
            else if (current.Left != null && current.Right != null)
            {
                // Two children: copy in the in-order predecessor/successor taken from the taller
                // side, then delete that node from the same side. Shrinking the taller side by at
                // most one level cannot unbalance this node.
                if (GetNodeHeight(current.Left) > GetNodeHeight(current.Right))
                {
                    var max = FindMax(current.Left);
                    current.Key = max.Key;
                    current.Value = max.Value;
                    current.Left = Delete(current.Left, current.Key);
                }
                else
                {
                    var min = FindMin(current.Right);
                    current.Key = min.Key;
                    current.Value = min.Value;
                    current.Right = Delete(current.Right, current.Key);
                }
            }
            else
            {
                // Zero or one child: splice the (possibly null) child into this node's place.
                var child = current.Left ?? current.Right;
                if (child != null) child.Parent = current.Parent;
                Count--;
                return child;
            }
            return Rebalance(current);
        }

        /// <summary>
        /// Restores the AVL invariant at <paramref name="current"/> (whose children are assumed
        /// to be balanced already) and refreshes its height.
        /// </summary>
        /// <param name="current">Non-null subtree root.</param>
        /// <returns>The root of the balanced subtree (may differ from <paramref name="current"/>).</returns>
        private BinaryTreeNode<TKey, TValue> Rebalance(BinaryTreeNode<TKey, TValue> current)
        {
            int balance = GetNodeHeight(current.Left) - GetNodeHeight(current.Right);
            if (balance > 1)
            {
                // Left-heavy: a double rotation is needed when the left child leans right.
                var left = current.Left;
                current = GetNodeHeight(left.Right) > GetNodeHeight(left.Left)
                    ? LRRotation(current)
                    : LLRotation(current);
            }
            else if (balance < -1)
            {
                // Right-heavy: a double rotation is needed when the right child leans left.
                var right = current.Right;
                current = GetNodeHeight(right.Left) > GetNodeHeight(right.Right)
                    ? RLRotation(current)
                    : RRRotation(current);
            }
            RecomputeHeight(current);
            return current;
        }

        /// <summary>
        /// Sets <paramref name="node"/>.Height to one more than the height of its taller child.
        /// </summary>
        /// <param name="node">Non-null node whose children already have correct heights.</param>
        private void RecomputeHeight(BinaryTreeNode<TKey, TValue> node)
        {
            node.Height = Math.Max(GetNodeHeight(node.Left), GetNodeHeight(node.Right)) + 1;
        }

        /// <summary>
        /// Single rotation for the left-left case: promotes the left child of
        /// <paramref name="current"/> (a rotation to the right) and fixes parent links and heights.
        /// </summary>
        /// <param name="current">Subtree root; its left child must be non-null.</param>
        /// <returns>The new subtree root (the former left child), or null if <paramref name="current"/> is null.</returns>
        private BinaryTreeNode<TKey, TValue> LLRotation(BinaryTreeNode<TKey, TValue> current)
        {
            if (current == null) return null;
            var pivot = current.Left;
            pivot.Parent = current.Parent;

            current.Left = pivot.Right;
            if (current.Left != null) current.Left.Parent = current;

            pivot.Right = current;
            current.Parent = pivot;

            RecomputeHeight(current);
            RecomputeHeight(pivot);
            return pivot;
        }

        /// <summary>
        /// Single rotation for the right-right case: promotes the right child of
        /// <paramref name="current"/> (a rotation to the left) and fixes parent links and heights.
        /// </summary>
        /// <param name="current">Subtree root; its right child must be non-null.</param>
        /// <returns>The new subtree root (the former right child), or null if <paramref name="current"/> is null.</returns>
        private BinaryTreeNode<TKey, TValue> RRRotation(BinaryTreeNode<TKey, TValue> current)
        {
            if (current == null) return null;
            var pivot = current.Right;
            pivot.Parent = current.Parent;

            current.Right = pivot.Left;
            if (current.Right != null) current.Right.Parent = current;

            pivot.Left = current;
            current.Parent = pivot;

            RecomputeHeight(current);
            RecomputeHeight(pivot);
            return pivot;
        }

        /// <summary>
        /// Double rotation for the left-right case: rotate the left child left, then rotate
        /// <paramref name="current"/> right.
        /// </summary>
        /// <param name="current">Subtree root whose left child leans right.</param>
        /// <returns>The new subtree root.</returns>
        private BinaryTreeNode<TKey, TValue> LRRotation(BinaryTreeNode<TKey, TValue> current)
        {
            current.Left = RRRotation(current.Left);
            return LLRotation(current);
        }

        /// <summary>
        /// Double rotation for the right-left case: rotate the right child right, then rotate
        /// <paramref name="current"/> left.
        /// </summary>
        /// <param name="current">Subtree root whose right child leans left.</param>
        /// <returns>The new subtree root.</returns>
        private BinaryTreeNode<TKey, TValue> RLRotation(BinaryTreeNode<TKey, TValue> current)
        {
            current.Right = LLRotation(current.Right);
            return RRRotation(current);
        }
    }
}
